# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from op_test import (
    convert_float_to_uint16,
    get_device_place,
    get_places,
    is_custom_device,
)

import paddle
from paddle import base
from paddle.framework import in_dynamic_mode, in_pir_mode

SUPPORTED_DTYPES = [
    bool,
    np.int8,
    np.int16,
    np.uint16,
    np.int32,
    np.int64,
    np.float16,
    np.float32,
    np.float64,
    np.complex64,
    np.complex128,
]

TEST_META_OP_DATA = [
    {'op_str': 'logical_and', 'binary_op': True},
    {'op_str': 'logical_or', 'binary_op': True},
    {'op_str': 'logical_xor', 'binary_op': True},
    {'op_str': 'logical_not', 'binary_op': False},
]

TEST_META_SHAPE_DATA = {
    'XDimLargerThanYDim1': {'x_shape': [2, 3, 4, 5], 'y_shape': [4, 5]},
    'XDimLargerThanYDim2': {'x_shape': [2, 3, 4, 5], 'y_shape': [4, 1]},
    'XDimLargerThanYDim3': {'x_shape': [2, 3, 4, 5], 'y_shape': [1, 4, 1]},
    'XDimLargerThanYDim4': {'x_shape': [2, 3, 4, 5], 'y_shape': [3, 4, 1]},
    'XDimLargerThanYDim5': {'x_shape': [2, 3, 1, 5], 'y_shape': [3, 1, 1]},
    'XDimLessThanYDim1': {'x_shape': [4, 1], 'y_shape': [2, 3, 4, 5]},
    'XDimLessThanYDim2': {'x_shape': [1, 4, 1], 'y_shape': [2, 3, 4, 5]},
    'XDimLessThanYDim3': {'x_shape': [3, 4, 1], 'y_shape': [2, 3, 4, 5]},
    'XDimLessThanYDim4': {'x_shape': [3, 1, 1], 'y_shape': [2, 3, 1, 5]},
    'XDimLessThanYDim5': {'x_shape': [4, 5], 'y_shape': [2, 3, 4, 5]},
    'Axis1InLargerDim': {'x_shape': [1, 4, 5], 'y_shape': [2, 3, 1, 5]},
    'EqualDim1': {'x_shape': [10, 7], 'y_shape': [10, 7]},
    'EqualDim2': {'x_shape': [1, 1, 4, 5], 'y_shape': [2, 3, 1, 5]},
    'ZeroDim1': {'x_shape': [], 'y_shape': []},
    'ZeroDim2': {'x_shape': [2, 3, 4, 5], 'y_shape': []},
    'ZeroDim3': {'x_shape': [], 'y_shape': [2, 3, 4, 5]},
}

TEST_META_WRONG_SHAPE_DATA = {
    'ErrorDim1': {'x_shape': [2, 3, 4, 5], 'y_shape': [3, 4]},
    'ErrorDim2': {'x_shape': [2, 3, 4, 5], 'y_shape': [4, 3]},
}


def run_static(x_np, y_np, op_str, use_gpu=False, binary_op=True):
    paddle.enable_static()
    startup_program = paddle.static.Program()
    main_program = paddle.static.Program()
    place = paddle.CPUPlace()
    if use_gpu and (paddle.is_compiled_with_cuda() or is_custom_device()):
        place = get_device_place()
    exe = paddle.static.Executor(place)
    with paddle.static.program_guard(main_program, startup_program):
        x = paddle.static.data(name='x', shape=x_np.shape, dtype=x_np.dtype)
        op = getattr(paddle, op_str)
        feed_list = {'x': x_np}
        if not binary_op:
            res = op(x)
        else:
            y = paddle.static.data(name='y', shape=y_np.shape, dtype=y_np.dtype)
            feed_list['y'] = y_np
            res = op(x, y)
        exe.run(startup_program)
        static_result = exe.run(main_program, feed=feed_list, fetch_list=[res])
    return static_result


def run_dygraph(x_np, y_np, op_str, use_gpu=False, binary_op=True):
    place = paddle.CPUPlace()
    if use_gpu and (paddle.is_compiled_with_cuda() or is_custom_device()):
        place = get_device_place()
    paddle.disable_static(place)
    op = getattr(paddle, op_str)
    x = paddle.to_tensor(x_np, dtype=x_np.dtype)
    if not binary_op:
        dygraph_result = op(x)
    else:
        y = paddle.to_tensor(y_np, dtype=y_np.dtype)
        dygraph_result = op(x, y)
    return dygraph_result


def run_eager(x_np, y_np, op_str, use_gpu=False, binary_op=True):
    place = paddle.CPUPlace()
    if use_gpu and (paddle.is_compiled_with_cuda() or is_custom_device()):
        place = get_device_place()
    paddle.disable_static(place)
    op = getattr(paddle, op_str)
    x = paddle.to_tensor(x_np, dtype=x_np.dtype)
    if not binary_op:
        dygraph_result = op(x)
    else:
        y = paddle.to_tensor(y_np, dtype=y_np.dtype)
        dygraph_result = op(x, y)
    return dygraph_result


def np_data_generator(np_shape, dtype, *args, **kwargs):
    if dtype == bool:
        return np.random.choice(a=[True, False], size=np_shape).astype(bool)
    elif dtype == np.uint16:
        x = np.random.uniform(0.0, 1.0, np_shape).astype(np.float32)
        return convert_float_to_uint16(x)
    elif dtype == np.complex64 or dtype == np.complex128:
        return np.random.normal(0, 1, np_shape).astype(dtype) + (
            1.0j * np.random.normal(0, 1, np_shape)
        ).astype(dtype)
    else:
        return np.random.normal(0, 1, np_shape).astype(dtype)


def test(unit_test, use_gpu=False, test_error=False):
    for op_data in TEST_META_OP_DATA:
        meta_data = dict(op_data)
        meta_data['use_gpu'] = use_gpu
        np_op = getattr(np, meta_data['op_str'])
        META_DATA = dict(TEST_META_SHAPE_DATA)
        if test_error:
            META_DATA = dict(TEST_META_WRONG_SHAPE_DATA)
        for shape_data in META_DATA.values():
            for data_type in SUPPORTED_DTYPES:
                if not (
                    (paddle.is_compiled_with_cuda() or is_custom_device())
                    and use_gpu
                ) and (data_type in [np.float16, np.uint16]):
                    continue
                meta_data['x_np'] = np_data_generator(
                    shape_data['x_shape'], dtype=data_type
                )
                meta_data['y_np'] = np_data_generator(
                    shape_data['y_shape'], dtype=data_type
                )
                if meta_data['binary_op'] and test_error:
                    # catch C++ Exception
                    unit_test.assertRaisesRegex(
                        ValueError,
                        r"\(InvalidArgument\) Broadcast dimension mismatch",
                        run_static,
                        **meta_data,
                    )
                    unit_test.assertRaisesRegex(
                        ValueError,
                        r"\(InvalidArgument\) Broadcast dimension mismatch",
                        run_dygraph,
                        **meta_data,
                    )
                    continue
                static_result = run_static(**meta_data)
                dygraph_result = run_dygraph(**meta_data)
                eager_result = run_eager(**meta_data)
                if meta_data['binary_op']:
                    np_result = np_op(meta_data['x_np'], meta_data['y_np'])
                else:
                    np_result = np_op(meta_data['x_np'])
                unit_test.assertTrue((static_result == np_result).all())
                unit_test.assertTrue(
                    (dygraph_result.numpy() == np_result).all()
                )
                unit_test.assertTrue((eager_result.numpy() == np_result).all())
            # add some corner case for complex datatype
            for complex_data_type in [np.complex64, np.complex128]:
                for x_data in (0 + 0j, 0 + 1j, 1 + 0j, 1 + 1j):
                    for y_data in (0 + 0j, 0 + 1j, 1 + 0j, 1 + 1j):
                        meta_data['x_np'] = (
                            x_data * np.ones(shape_data['x_shape'])
                        ).astype(complex_data_type)
                        meta_data['y_np'] = (
                            y_data * np.ones(shape_data['y_shape'])
                        ).astype(complex_data_type)
                        if meta_data['binary_op'] and test_error:
                            # catch C++ Exception
                            unit_test.assertRaisesRegex(
                                ValueError,
                                r"\(InvalidArgument\) Broadcast dimension mismatch",
                                run_static,
                                **meta_data,
                            )
                            unit_test.assertRaisesRegex(
                                ValueError,
                                r"\(InvalidArgument\) Broadcast dimension mismatch",
                                run_dygraph,
                                **meta_data,
                            )
                            continue
                        static_result = run_static(**meta_data)
                        dygraph_result = run_dygraph(**meta_data)
                        eager_result = run_eager(**meta_data)
                        if meta_data['binary_op']:
                            np_result = np_op(
                                meta_data['x_np'], meta_data['y_np']
                            )
                        else:
                            np_result = np_op(meta_data['x_np'])
                        unit_test.assertTrue((static_result == np_result).all())
                        unit_test.assertTrue(
                            (dygraph_result.numpy() == np_result).all()
                        )
                        unit_test.assertTrue(
                            (eager_result.numpy() == np_result).all()
                        )


def test_type_error(unit_test, use_gpu, type_str_map):
    def check_type(op_str, x, y, binary_op):
        op = getattr(paddle, op_str)
        # The C++ backend raises TypeError for invalid type promotion.
        error_type = TypeError
        if isinstance(x, np.ndarray):
            x = paddle.to_tensor(x)
            y = paddle.to_tensor(y)
            # Use TypeError for dygraph as well to be more specific.
            error_type = TypeError

        if binary_op:
            type_x = type_str_map['x']
            type_y = type_str_map['y']
            if type_x != type_y:
                floating_dtypes = {
                    np.float16,
                    np.float32,
                    np.float64,
                    np.uint16,
                }
                complex_dtypes = {np.complex64, np.complex128}

                is_x_fp = type_x in floating_dtypes
                is_y_fp = type_y in floating_dtypes
                is_x_complex = type_x in complex_dtypes
                is_y_complex = type_y in complex_dtypes

                # Type promotion is supported between floating-point numbers,
                # and between complex and real numbers.
                promotion_allowed = (
                    (is_x_fp and is_y_fp) or is_x_complex or is_y_complex
                )

                if not promotion_allowed:
                    unit_test.assertRaises(error_type, op, x=x, y=y)

            if not in_dynamic_mode():
                error_type = TypeError
                # Skip this test in PIR mode because the C++ backend has a known bug
                # of ignoring the `out` parameter, which prevents the TypeError.
                if not in_pir_mode():
                    unit_test.assertRaises(error_type, op, x=x, y=y, out=1)
        else:
            if not in_dynamic_mode():
                error_type = TypeError
                if not in_pir_mode():
                    unit_test.assertRaises(error_type, op, x=x, out=1)

    place = paddle.CPUPlace()
    if use_gpu and (paddle.is_compiled_with_cuda() or is_custom_device()):
        place = get_device_place()
    for op_data in TEST_META_OP_DATA:
        if (
            (paddle.is_compiled_with_cuda() or is_custom_device())
            and use_gpu
            and (
                type_str_map['x'] in [np.float16, np.uint16]
                or type_str_map['y'] in [np.float16, np.uint16]
            )
        ):
            continue

        meta_data = dict(op_data)
        binary_op = meta_data['binary_op']

        paddle.disable_static(place)
        x = np.random.choice(a=[0, 1], size=[10]).astype(type_str_map['x'])
        y = np.random.choice(a=[0, 1], size=[10]).astype(type_str_map['y'])
        check_type(meta_data['op_str'], x, y, binary_op)

        paddle.enable_static()
        startup_program = paddle.static.Program()
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            x = paddle.static.data(
                name='x', shape=[10], dtype=type_str_map['x']
            )
            y = paddle.static.data(
                name='y', shape=[10], dtype=type_str_map['y']
            )
            check_type(meta_data['op_str'], x, y, binary_op)


def type_map_factory():
    return [
        {'x': x_type, 'y': y_type}
        for x_type in SUPPORTED_DTYPES
        for y_type in SUPPORTED_DTYPES
    ]


class TestCPU(unittest.TestCase):
    def test(self):
        test(self)

    def test_error(self):
        test(self, False, True)

    def test_type_error(self):
        type_map_list = type_map_factory()
        for type_map in type_map_list:
            test_type_error(self, False, type_map)


class TestCUDA(unittest.TestCase):
    def test(self):
        test(self, True)

    def test_error(self):
        test(self, True, True)

    def test_type_error(self):
        type_map_list = type_map_factory()
        for type_map in type_map_list:
            test_type_error(self, True, type_map)


class TestLogicalOpsAPI_Compatibility(unittest.TestCase):
    def setUp(self):
        np.random.seed(123)
        paddle.enable_static()
        self.places = get_places()
        self.shape = [10, 20]
        self.dtype = 'bool'

    def test_dygraph_api_compatibility(self):
        paddle.disable_static()
        for op_info in TEST_META_OP_DATA:
            op_str = op_info['op_str']
            is_binary = op_info['binary_op']
            with self.subTest(op=op_str):
                np_input = np.random.choice([True, False], size=self.shape)
                x = paddle.to_tensor(np_input)
                paddle_op = getattr(paddle, op_str)
                ref_op = getattr(np, op_str)

                paddle_dygraph_out = []

                if is_binary:
                    np_other = np.random.choice([True, False], size=self.shape)
                    y = paddle.to_tensor(np_other)
                    # Position args (args)
                    paddle_dygraph_out.append(paddle_op(x, y))
                    # Keywords args (kwargs) for paddle
                    paddle_dygraph_out.append(paddle_op(x=x, y=y))
                    # Keywords args for torch
                    paddle_dygraph_out.append(paddle_op(input=x, other=y))
                    # Combined args and kwargs
                    paddle_dygraph_out.append(paddle_op(x, other=y))
                    # Tensor method args
                    paddle_dygraph_out.append(x.__getattribute__(op_str)(y))
                    # Tensor method kwargs
                    paddle_dygraph_out.append(
                        x.__getattribute__(op_str)(other=y)
                    )

                    # Test out
                    out_tensor = paddle.empty(self.shape, dtype=self.dtype)
                    paddle_op(x, y, out=out_tensor)
                    paddle_dygraph_out.append(out_tensor)

                    # Numpy reference out
                    ref_out = ref_op(np_input, np_other)
                else:  # Unary op (logical_not)
                    # Position args (args)
                    paddle_dygraph_out.append(paddle_op(x))
                    # Keywords args (kwargs) for paddle
                    paddle_dygraph_out.append(paddle_op(x=x))
                    # Keywords args for torch
                    paddle_dygraph_out.append(paddle_op(input=x))
                    # Tensor method args
                    paddle_dygraph_out.append(x.__getattribute__(op_str)())

                    # Test out
                    out_tensor = paddle.empty(self.shape, dtype=self.dtype)
                    paddle_op(x, out=out_tensor)
                    paddle_dygraph_out.append(out_tensor)

                    # Numpy reference out
                    ref_out = ref_op(np_input)

                # Check
                for out in paddle_dygraph_out:
                    np.testing.assert_equal(ref_out, out.numpy())

        paddle.enable_static()

    def test_static_api_compatibility(self):
        for op_info in TEST_META_OP_DATA:
            op_str = op_info['op_str']
            is_binary = op_info['binary_op']

            with self.subTest(op=op_str):
                np_input = np.random.choice([True, False], size=self.shape)
                ref_op = getattr(np, op_str)

                main = paddle.static.Program()
                startup = paddle.static.Program()
                with base.program_guard(main, startup):
                    x = paddle.static.data(
                        name="x", shape=self.shape, dtype=self.dtype
                    )
                    paddle_op = getattr(paddle, op_str)

                    fetch_list = []
                    feed_dict = {"x": np_input}

                    if is_binary:
                        np_other = np.random.choice(
                            [True, False], size=self.shape
                        )
                        y = paddle.static.data(
                            name="y", shape=self.shape, dtype=self.dtype
                        )
                        feed_dict["y"] = np_other

                        # Position args (args)
                        fetch_list.append(paddle_op(x, y))
                        # Keywords args (kwargs) for paddle
                        fetch_list.append(paddle_op(x=x, y=y))
                        # Keywords args for torch
                        fetch_list.append(paddle_op(input=x, other=y))
                        # Combined args and kwargs
                        fetch_list.append(paddle_op(x, other=y))
                        # Tensor method args
                        fetch_list.append(x.__getattribute__(op_str)(y))
                        # Tensor method kwargs
                        fetch_list.append(x.__getattribute__(op_str)(other=y))

                        # Numpy reference out
                        ref_out = ref_op(np_input, np_other)
                    else:  # Unary op
                        # Position args (args)
                        fetch_list.append(paddle_op(x))
                        # Keywords args (kwargs) for paddle
                        fetch_list.append(paddle_op(x=x))
                        # Keywords args for torch
                        fetch_list.append(paddle_op(input=x))
                        # Tensor method args
                        fetch_list.append(x.__getattribute__(op_str)())

                        # Numpy reference out
                        ref_out = ref_op(np_input)

                    for place in self.places:
                        exe = base.Executor(place)
                        outs = exe.run(
                            main, feed=feed_dict, fetch_list=fetch_list
                        )
                        # Check
                        for out in outs:
                            np.testing.assert_equal(ref_out, out)


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()
